# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------

# Autogenerated By   : src/main/python/generator/generator.py
# Autogenerated From : scripts/builtin/autoencoder_2layer.dml

from typing import Dict, Iterable

from systemds.operator import OperationNode, Matrix, Frame, List, MultiReturn, Scalar
from systemds.utils.consts import VALID_INPUT_TYPES


def autoencoder_2layer(X: Matrix,
                       num_hidden1: int,
                       num_hidden2: int,
                       max_epochs: int,
                       **kwargs: Dict[str, VALID_INPUT_TYPES]) -> MultiReturn:
    """
     Trains a 2-layer autoencoder with minibatch SGD and step-size decay.
     If invoked with H1 > H2 then it becomes a 'bowtie' structured autoencoder.
     Weights are initialized using Glorot & Bengio (2010) AISTATS initialization.
     The script standardizes the input before training (can be turned off).
     Also, it randomly reshuffles rows before training.
     Currently, tanh is set to be the activation function.
     By re-implementing 'func' DML-bodied function, one can change the activation.

    :param X: Matrix node holding the input data to train the autoencoder on
    :param num_hidden1: Number of neurons in the 1st hidden layer
    :param num_hidden2: Number of neurons in the 2nd hidden layer
    :param max_epochs: Number of epochs to train for
    :param full_obj: If TRUE, Computes objective function value (squared-loss)
        at the end of each epoch. Note that, computing the full
        objective can take a lot of time.
    :param batch_size: Mini-batch size (training parameter)
    :param step: Initial step size (training parameter)
    :param decay: Decays step size after each epoch (training parameter)
    :param mu: Momentum parameter (training parameter)
    :param W1_rand: Weights might be initialized via input matrices
    :param W2_rand: Optional initial weights for the 2nd weight matrix (see W1_rand)
    :param W3_rand: Optional initial weights for the 3rd weight matrix (see W1_rand)
    :param W4_rand: Optional initial weights for the 4th weight matrix (see W1_rand)
    :return: A MultiReturn node with nine Matrix outputs, in this order:
        1. weights between input layer and 1st hidden layer
        2. bias between input layer and 1st hidden layer
        3. weights between 1st hidden layer and 2nd hidden layer
        4. bias between 1st hidden layer and 2nd hidden layer
        5. weights between 2nd hidden layer and 3rd hidden layer
        6. bias between 2nd hidden layer and 3rd hidden layer
        7. weights between 3rd hidden layer and output layer
        8. bias between 3rd hidden layer and output layer
        9. the hidden (2nd) layer representation, if needed
    """
    # Required arguments first; the optional arguments documented above arrive
    # through kwargs and are merged in verbatim (a same-named kwarg overrides).
    params_dict = {'X': X, 'num_hidden1': num_hidden1,
                   'num_hidden2': num_hidden2, 'max_epochs': max_epochs}
    params_dict.update(kwargs)

    # One placeholder Matrix per output of the DML builtin (see :return:).
    num_outputs = 9
    output_nodes = [Matrix(X.sds_context, '') for _ in range(num_outputs)]

    op = MultiReturn(X.sds_context, 'autoencoder_2layer', output_nodes,
                     named_input_nodes=params_dict)

    # Wire every placeholder back to the MultiReturn node that produces it;
    # each node gets its own single-element input list.
    for node in output_nodes:
        node._unnamed_input_nodes = [op]

    return op
